{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Intervention Demo\n",
    "\n",
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/safety-research/circuit-tracer/blob/main/demos/intervention_demo.ipynb\">\n",
    "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
    "</a>\n",
    "\n",
    "In this demo, you'll learn how to perform interventions on models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "#@title Colab Setup Environment\n",
    "\n",
    "try:\n",
    "    import google.colab\n",
    "    !mkdir -p repository && cd repository && \\\n",
    "     git clone https://github.com/safety-research/circuit-tracer && \\\n",
    "     curl -LsSf https://astral.sh/uv/install.sh | sh && \\\n",
    "     uv pip install -e circuit-tracer/\n",
    "\n",
    "    import sys\n",
    "    from huggingface_hub import notebook_login\n",
    "    sys.path.append('repository/circuit-tracer')\n",
    "    sys.path.append('repository/circuit-tracer/demos')\n",
    "    notebook_login(new_session=False)\n",
    "    IN_COLAB = True\n",
    "except ImportError:\n",
    "    IN_COLAB = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import namedtuple\n",
    "from functools import partial\n",
    "\n",
    "import torch \n",
    "\n",
    "from circuit_tracer import ReplacementModel\n",
    "\n",
    "# display functions\n",
    "from utils import display_topk_token_predictions, display_generations_comparison"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we load our models (see `attribute_demo.ipynb` for more details)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dd9174b8e151411cae1cb9f2f2463edf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Fetching 26 files:   0%|          | 0/26 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fc431133d4424e84a339b6e17df6444d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded pretrained model google/gemma-2-2b into HookedTransformer\n"
     ]
    }
   ],
   "source": [
    "model = ReplacementModel.from_pretrained(\"google/gemma-2-2b\", \"gemma\", dtype=torch.bfloat16)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll write some helper functions to print the top next tokens of our model, and a class to store features in."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "Feature = namedtuple('Feature', ['layer', 'pos', 'feature_idx'])\n",
    "\n",
    "# a display function that needs the model's tokenizer\n",
    "display_topk_token_predictions = partial(display_topk_token_predictions, tokenizer=model.tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example: Changing languages with zero ablations\n",
    "\n",
    "Imagine that you have the following [annotated attribution graph](https://www.neuronpedia.org/gemma-2-2b/graph?slug=gemma-michael-jordan-es&clerps=%5B%5B%222308855%22%2C%22sports%22%5D%2C%5B%222502222%22%2C%22Spanish+articles%22%5D%2C%5B%222513416%22%2C%22Spanish%22%5D%2C%5B%222509334%22%2C%22Spanish%22%5D%2C%5B%222413490%22%2C%22Spanish%22%5D%2C%5B%222403018%22%2C%22Spanish%22%5D%2C%5B%222407980%22%2C%22Spanish+articles%22%5D%2C%5B%222511463%22%2C%22Spanish%22%5D%2C%5B%222104818%22%2C%22basketball%22%5D%2C%5B%222109324%22%2C%22sports%22%5D%2C%5B%222009090%22%2C%22basketball%22%5D%2C%5B%221712431%22%2C%22sports%22%5D%2C%5B%221515208%22%2C%22play%22%5D%2C%5B%22401305%22%2C%22game%22%5D%2C%5B%22109339%22%2C%22a%2Fal+in+Spanish%22%5D%2C%5B%2213978%22%2C%22romance+languages%22%5D%2C%5B%2215822%22%2C%22romance+languages%22%5D%2C%5B%221404939%22%2C%22play%22%5D%2C%5B%221915763%22%2C%22sports%22%5D%2C%5B%221812672%22%2C%22basketball%22%5D%2C%5B%221414510%22%2C%22sports%22%5D%2C%5B%22401742%22%2C%22basketball%22%5D%2C%5B%22101173%22%2C%22basketball%22%5D%2C%5B%22411%22%2C%22famous+people+%2F+named+entities%22%5D%2C%5B%222000341%22%2C%22Spanish%22%5D%2C%5B%222303604%22%2C%22sports+%2F+table+tennis+%2F+pool+%22%5D%2C%5B%222413277%22%2C%22%28incomprehensible%29%22%5D%5D&pinnedIds=27_143831_6%2C25_13416_6%2C24_3018_6%2C25_9334_6%2C24_13490_6%2C25_2222_6%2C24_7980_6%2C25_11463_6%2C21_9324_6%2C21_4818_6%2C23_8855_6%2C20_9090_6%2C17_12431_6%2C15_15208_6%2C14_4939_6%2C4_1305_6%2C1_9339_6%2CE_113501_5%2C0_13978_5%2C0_15822_5%2CE_717_6%2C19_15763_6%2C18_12672_6%2C4_1742_4%2C14_14510_4%2C1_1173_4%2CE_18853_4%2CE_7939_3%2C0_411_4%2C20_341_6&supernodes=%5B%5B%22basketball%22%2C%2220_9090_6%22%2C%2218_12672_6%22%2C%2221_4818_6%22%2C%2223_8855_6%22%5D%2C%5B%22sports%22%2C%2217_12431_6%22%2C%2219_15763_6%22%2C%2221_9324_6%22%5D%2C%5B%22play%22%2C%224_1305_6%22%2C%2214_4939_6%22%2C%2215_15208_6%22%5D%2C%5B%22basketball%22%2C%224_1742_4%22%2C%221_1173_4%22%5D%2C%5B%22romance+language%22
%2C%221_9339_6%22%2C%220_15822_5%22%2C%220_13978_5%22%5D%2C%5B%22Spanish%22%2C%2225_9334_6%22%2C%2225_13416_6%22%2C%2224_13490_6%22%2C%2224_7980_6%22%2C%2224_3018_6%22%2C%2225_2222_6%22%2C%2225_11463_6%22%2C%2220_341_6%22%5D%5D&clickedId=20_341_6) showing the circuit for the Spanish sentence *Hecho: Michael Jordan juega al*, or in English, *Fact: Michael Jordan plays*. The correct answer, which the model correctly predicts, is *baloncesto*, or *basketball*. We observe a supernode of features that correspond to the Spanish language. Can we intervene on these features to change the model's output?\n",
    "\n",
    "<img src=\"https://raw.githubusercontent.com/safety-research/circuit-tracer/main/demos/img/gemma/mj-basketball-es.png\" width=\"400\">\n",
    "\n",
    "First, we can try to do this by identifying these supernode features, which we store below. For each, we store their layer, position (here, always -1, as all of these features are active at the final position), and feature ID. For the sake of convenience, we'll only add one supernode feature."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "supernode_features = [\n",
    "    Feature(layer=20,pos=-1,feature_idx=341),\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we need to turn our supernode features into a list of intervention tuples. These tuples are formatted as (layer, node, feature_idx, new_feature_value). For now, let's try just zeroing out these features at the last position."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "intervention_tuples = [(*supernode_feature, 0.0) for supernode_feature in supernode_features]"
   ]
  },
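  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The tuple construction above can be checked in isolation. The following is a minimal, standalone sketch that needs no model; the helper name `zero_ablation_tuples` is hypothetical (it is not part of `circuit_tracer`) and simply wraps the list comprehension used in this demo.\n",
    "\n",
    "```python\n",
    "from collections import namedtuple\n",
    "\n",
    "# same field order as the Feature class defined earlier in this demo\n",
    "Feature = namedtuple('Feature', ['layer', 'pos', 'feature_idx'])\n",
    "\n",
    "\n",
    "def zero_ablation_tuples(features):\n",
    "    # hypothetical helper: unpack each Feature and append the new value 0.0\n",
    "    return [(*feature, 0.0) for feature in features]\n",
    "\n",
    "\n",
    "features = [Feature(layer=20, pos=-1, feature_idx=341)]\n",
    "print(zero_ablation_tuples(features))  # [(20, -1, 341, 0.0)]\n",
    "```\n",
    "\n",
    "Because `Feature` is a named tuple, `*feature` unpacks to `(layer, pos, feature_idx)` in order, so the field order has to match the order the intervention tuples expect."
   ]
  },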
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we can run the intervention and view its effects on the model's output!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <style>\n",
       "    .token-viz {\n",
       "        font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;\n",
       "        margin-bottom: 10px;\n",
       "        max-width: 700px;\n",
       "    }\n",
       "    .token-viz .header {\n",
       "        font-weight: bold;\n",
       "        font-size: 14px;\n",
       "        margin-bottom: 3px;\n",
       "        padding: 4px 6px;\n",
       "        border-radius: 3px;\n",
       "        color: white;\n",
       "        display: inline-block;\n",
       "    }\n",
       "    .token-viz .sentence {\n",
       "        background-color: rgba(200, 200, 200, 0.2);\n",
       "        padding: 4px 6px;\n",
       "        border-radius: 3px;\n",
       "        border: 1px solid rgba(100, 100, 100, 0.5);\n",
       "        font-family: monospace;\n",
       "        margin-bottom: 8px;\n",
       "        font-weight: 500;\n",
       "        font-size: 14px;\n",
       "    }\n",
       "    .token-viz table {\n",
       "        width: 100%;\n",
       "        border-collapse: collapse;\n",
       "        margin-bottom: 8px;\n",
       "        font-size: 13px;\n",
       "        table-layout: fixed;\n",
       "    }\n",
       "    .token-viz th {\n",
       "        text-align: left;\n",
       "        padding: 4px 6px;\n",
       "        font-weight: bold;\n",
       "        border: 1px solid rgba(150, 150, 150, 0.5);\n",
       "        background-color: rgba(200, 200, 200, 0.3);\n",
       "    }\n",
       "    .token-viz td {\n",
       "        padding: 3px 6px;\n",
       "        border: 1px solid rgba(150, 150, 150, 0.5);\n",
       "        font-weight: 500;\n",
       "        overflow: hidden;\n",
       "        text-overflow: ellipsis;\n",
       "        white-space: nowrap;\n",
       "    }\n",
       "    .token-viz .token-col {\n",
       "        width: 20%;\n",
       "    }\n",
       "    .token-viz .prob-col {\n",
       "        width: 15%;\n",
       "    }\n",
       "    .token-viz .dist-col {\n",
       "        width: 65%;\n",
       "    }\n",
       "    .token-viz .monospace {\n",
       "        font-family: monospace;\n",
       "    }\n",
       "    .token-viz .bar-container {\n",
       "        display: flex;\n",
       "        align-items: center;\n",
       "    }\n",
       "    .token-viz .bar {\n",
       "        height: 12px;\n",
       "        min-width: 2px;\n",
       "    }\n",
       "    .token-viz .bar-text {\n",
       "        margin-left: 6px;\n",
       "        font-weight: 500;\n",
       "        font-size: 12px;\n",
       "    }\n",
       "    .token-viz .even-row {\n",
       "        background-color: rgba(240, 240, 240, 0.1);\n",
       "    }\n",
       "    .token-viz .odd-row {\n",
       "        background-color: rgba(255, 255, 255, 0.1);\n",
       "    }\n",
       "    </style>\n",
       "    \n",
       "    <div class=\"token-viz\">\n",
       "        <div class=\"header\" style=\"background-color: #555555;\">Input Sentence:</div>\n",
       "        <div class=\"sentence\">Hecho: Michael Jordan juega al</div>\n",
       "        \n",
       "        <div>\n",
       "            <div class=\"header\" style=\"background-color: #2471A3;\">Original Top 5 Tokens</div>\n",
       "            <table>\n",
       "                <thead>\n",
       "                    <tr>\n",
       "                        <th class=\"token-col\">Token</th>\n",
       "                        <th class=\"prob-col\" style=\"text-align: right;\">Probability</th>\n",
       "                        <th class=\"dist-col\">Distribution</th>\n",
       "                    </tr>\n",
       "                </thead>\n",
       "                <tbody>\n",
       "    \n",
       "                    <tr class=\"even-row\">\n",
       "                        <td class=\"monospace token-col\" title=\" baloncesto\"> baloncesto</td>\n",
       "                        <td class=\"prob-col\" style=\"text-align: right;\">0.613</td>\n",
       "                        <td class=\"dist-col\">\n",
       "                            <div class=\"bar-container\">\n",
       "                                <div class=\"bar\" style=\"background-color: #2471A3; width: 100%;\"></div>\n",
       "                                <span class=\"bar-text\">61.3%</span>\n",
       "                            </div>\n",
       "                        </td>\n",
       "                    </tr>\n",
       "        \n",
       "                    <tr class=\"odd-row\">\n",
       "                        <td class=\"monospace token-col\" title=\" golf\"> golf</td>\n",
       "                        <td class=\"prob-col\" style=\"text-align: right;\">0.050</td>\n",
       "                        <td class=\"dist-col\">\n",
       "                            <div class=\"bar-container\">\n",
       "                                <div class=\"bar\" style=\"background-color: #2471A3; width: 8%;\"></div>\n",
       "                                <span class=\"bar-text\">5.0%</span>\n",
       "                            </div>\n",
       "                        </td>\n",
       "                    </tr>\n",
       "        \n",
       "                    <tr class=\"even-row\">\n",
       "                        <td class=\"monospace token-col\" title=\" fútbol\"> fútbol</td>\n",
       "                        <td class=\"prob-col\" style=\"text-align: right;\">0.044</td>\n",
       "                        <td class=\"dist-col\">\n",
       "                            <div class=\"bar-container\">\n",
       "                                <div class=\"bar\" style=\"background-color: #2471A3; width: 7%;\"></div>\n",
       "                                <span class=\"bar-text\">4.4%</span>\n",
       "                            </div>\n",
       "                        </td>\n",
       "                    </tr>\n",
       "        \n",
       "                    <tr class=\"odd-row\">\n",
       "                        <td class=\"monospace token-col\" title=\" bás\"> bás</td>\n",
       "                        <td class=\"prob-col\" style=\"text-align: right;\">0.044</td>\n",
       "                        <td class=\"dist-col\">\n",
       "                            <div class=\"bar-container\">\n",
       "                                <div class=\"bar\" style=\"background-color: #2471A3; width: 7%;\"></div>\n",
       "                                <span class=\"bar-text\">4.4%</span>\n",
       "                            </div>\n",
       "                        </td>\n",
       "                    </tr>\n",
       "        \n",
       "                    <tr class=\"even-row\">\n",
       "                        <td class=\"monospace token-col\" title=\" béisbol\"> béisbol</td>\n",
       "                        <td class=\"prob-col\" style=\"text-align: right;\">0.035</td>\n",
       "                        <td class=\"dist-col\">\n",
       "                            <div class=\"bar-container\">\n",
       "                                <div class=\"bar\" style=\"background-color: #2471A3; width: 5%;\"></div>\n",
       "                                <span class=\"bar-text\">3.5%</span>\n",
       "                            </div>\n",
       "                        </td>\n",
       "                    </tr>\n",
       "        \n",
       "                </tbody>\n",
       "            </table>\n",
       "            \n",
       "            <div class=\"header\" style=\"background-color: #27AE60;\">New Top 5 Tokens</div>\n",
       "            <table>\n",
       "                <thead>\n",
       "                    <tr>\n",
       "                        <th class=\"token-col\">Token</th>\n",
       "                        <th class=\"prob-col\" style=\"text-align: right;\">Probability</th>\n",
       "                        <th class=\"dist-col\">Distribution</th>\n",
       "                    </tr>\n",
       "                </thead>\n",
       "                <tbody>\n",
       "    \n",
       "                    <tr class=\"even-row\">\n",
       "                        <td class=\"monospace token-col\" title=\" baloncesto\"> baloncesto</td>\n",
       "                        <td class=\"prob-col\" style=\"text-align: right;\">0.598</td>\n",
       "                        <td class=\"dist-col\">\n",
       "                            <div class=\"bar-container\">\n",
       "                                <div class=\"bar\" style=\"background-color: #27AE60; width: 97%;\"></div>\n",
       "                                <span class=\"bar-text\">59.8%</span>\n",
       "                            </div>\n",
       "                        </td>\n",
       "                    </tr>\n",
       "        \n",
       "                    <tr class=\"odd-row\">\n",
       "                        <td class=\"monospace token-col\" title=\" golf\"> golf</td>\n",
       "                        <td class=\"prob-col\" style=\"text-align: right;\">0.063</td>\n",
       "                        <td class=\"dist-col\">\n",
       "                            <div class=\"bar-container\">\n",
       "                                <div class=\"bar\" style=\"background-color: #27AE60; width: 10%;\"></div>\n",
       "                                <span class=\"bar-text\">6.3%</span>\n",
       "                            </div>\n",
       "                        </td>\n",
       "                    </tr>\n",
       "        \n",
       "                    <tr class=\"even-row\">\n",
       "                        <td class=\"monospace token-col\" title=\" bás\"> bás</td>\n",
       "                        <td class=\"prob-col\" style=\"text-align: right;\">0.056</td>\n",
       "                        <td class=\"dist-col\">\n",
       "                            <div class=\"bar-container\">\n",
       "                                <div class=\"bar\" style=\"background-color: #27AE60; width: 9%;\"></div>\n",
       "                                <span class=\"bar-text\">5.6%</span>\n",
       "                            </div>\n",
       "                        </td>\n",
       "                    </tr>\n",
       "        \n",
       "                    <tr class=\"odd-row\">\n",
       "                        <td class=\"monospace token-col\" title=\" fútbol\"> fútbol</td>\n",
       "                        <td class=\"prob-col\" style=\"text-align: right;\">0.034</td>\n",
       "                        <td class=\"dist-col\">\n",
       "                            <div class=\"bar-container\">\n",
       "                                <div class=\"bar\" style=\"background-color: #27AE60; width: 5%;\"></div>\n",
       "                                <span class=\"bar-text\">3.4%</span>\n",
       "                            </div>\n",
       "                        </td>\n",
       "                    </tr>\n",
       "        \n",
       "                    <tr class=\"even-row\">\n",
       "                        <td class=\"monospace token-col\" title=\" basketball\"> basketball</td>\n",
       "                        <td class=\"prob-col\" style=\"text-align: right;\">0.034</td>\n",
       "                        <td class=\"dist-col\">\n",
       "                            <div class=\"bar-container\">\n",
       "                                <div class=\"bar\" style=\"background-color: #27AE60; width: 5%;\"></div>\n",
       "                                <span class=\"bar-text\">3.4%</span>\n",
       "                            </div>\n",
       "                        </td>\n",
       "                    </tr>\n",
       "        \n",
       "                </tbody>\n",
       "            </table>\n",
       "        </div>\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "s = \"Hecho: Michael Jordan juega al\"\n",
    "\n",
    "with torch.inference_mode():\n",
    "    original_logits = model(s)\n",
    "    new_logits, _ = model.feature_intervention(s, intervention_tuples)\n",
    "\n",
    "display_topk_token_predictions(s, original_logits, new_logits)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That wasn't very effective! We do see that the probability of *basketball* has risen, bringing it into the top 5. But intervening on just one feature isn't enough to change the model's behavior dramatically; the rest of the distribution remains more or less the same. This is because many Spanish features contribute to our model's output language, while we have changed only one. Try changing more!"
   ]
  },
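  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Comparisons like the one above boil down to turning next-token logits into probabilities and ranking them. The following is a rough, pure-Python sketch of that step under stated assumptions: the logit values are made up for illustration, and `topk_probs` is a hypothetical helper, not the implementation of `display_topk_token_predictions`.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def topk_probs(logits, k=5):\n",
    "    # logits: dict mapping token string -> logit (toy stand-in for a vocab-sized tensor)\n",
    "    max_logit = max(logits.values())  # subtract the max for numerical stability\n",
    "    exps = {tok: math.exp(val - max_logit) for tok, val in logits.items()}\n",
    "    total = sum(exps.values())\n",
    "    probs = {tok: e / total for tok, e in exps.items()}\n",
    "    # highest-probability tokens first\n",
    "    return sorted(probs.items(), key=lambda item: item[1], reverse=True)[:k]\n",
    "\n",
    "\n",
    "# made-up logits, for illustration only\n",
    "toy_logits = {' baloncesto': 5.0, ' golf': 2.5, ' basketball': 2.0, ' tenis': 0.5}\n",
    "for token, prob in topk_probs(toy_logits, k=3):\n",
    "    print(f'{token!r}: {prob:.3f}')\n",
    "```\n",
    "\n",
    "Applying the same function to the original and the post-intervention logits gives the two rankings to compare."
   ]
  },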
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example: Swapping languages by turning features on\n",
    "\n",
    "In the last example, we only turned Spanish features off, yielding text in English, which seems to be the model's default. But what if we wanted to swap to another language? \n",
    "Then we'd have to turn language features from that language on. Let's try this with another language, French. Here is the attribution graph for the analogous French sentence, [`Fait: Michael Jordan joue au` →`basket`](https://www.neuronpedia.org/gemma-2-2b/graph?slug=gemma-basket&clickedId=17_10566_2&clerps=%5B%5B%222308855%22%2C%22basketball%22%5D%2C%5B%222502222%22%2C%22Spanish+articles%22%5D%2C%5B%222513416%22%2C%22Spanish%22%5D%2C%5B%222104818%22%2C%22basketball%22%5D%2C%5B%222109324%22%2C%22sports%22%5D%2C%5B%222009090%22%2C%22basketball%22%5D%2C%5B%221712431%22%2C%22sports%22%5D%2C%5B%221515208%22%2C%22play%22%5D%2C%5B%22401305%22%2C%22game%22%5D%2C%5B%2213978%22%2C%22romance+languages%22%5D%2C%5B%2215822%22%2C%22romance+languages%22%5D%2C%5B%221404939%22%2C%22play%22%5D%2C%5B%221915763%22%2C%22sports%22%5D%2C%5B%221812672%22%2C%22basketball%22%5D%2C%5B%221414510%22%2C%22sports%22%5D%2C%5B%22401742%22%2C%22basketball%22%5D%2C%5B%22101173%22%2C%22basketball%22%5D%2C%5B%22411%22%2C%22famous+people+%2F+named+entities%22%5D%2C%5B%221710566%22%2C%22French%22%5D%5D&pinnedIds=27_12220_7%2CE_18853_5%2C21_4818_7%2C21_9324_7%2C23_3604_7%2C25_14882_7%2C24_15306_7%2C23_15317_7%2C20_9090_7%2C24_3329_7%2C19_15763_7%2C18_12672_7%2C17_12431_7%2C17_5253_7%2C15_15208_7%2C14_4939_7%2C6_7377_7%2CE_78224_6%2C4_1305_7%2C3_305_7%2C24_2086_7%2C24_3772_7%2C21_16354_7%2C20_1454_7%2C23_2592_7%2C22_10566_7%2C23_2554_7%2C17_10566_6%2C0_4076_6%2C14_14575_6%2C7_11689_6%2C4_1742_5%2C1_1173_5%2CE_7939_4&supernodes=%5B%5B%22game%2Fplay%22%2C%223_305_7%22%2C%224_1305_7%22%2C%226_7377_7%22%2C%2215_15208_7%22%2C%2214_4939_7%22%5D%2C%5B%22French%22%2C%220_4076_6%22%2C%227_11689_6%22%2C%2214_14575_6%22%2C%2217_10566_6%22%5D%2C%5B%22basketball%22%2C%2221_4818_7%22%2C%2218_12672_7%22%5D%2C%5B%22sports%22%2C%2217_12431_7%22%2C%2217_5253_7%22%2C%2221_9324_7%22%2C%2220_9090_7%22%2C%2219_15763_7%22%2C%2223_3604_7%22%2C%2223_15317_7%22%5D%2C%5B%22basketball%22%2C%224_1742_5%22%2C%221_1173_5%22%5D%2C%
5B%22French%22%2C%2224_3329_7%22%2C%2221_16354_7%22%2C%2220_1454_7%22%2C%2223_2592_7%22%2C%2223_2554_7%22%2C%2224_2086_7%22%2C%2224_15306_7%22%2C%2225_14882_7%22%2C%2224_3772_7%22%2C%2222_10566_7%22%5D%5D). \n",
    "\n",
    "<img src=\"https://raw.githubusercontent.com/safety-research/circuit-tracer/main/demos/img/gemma/mj-basketball-fr.png\" width=\"400\">\n",
    "\n",
    "The answer to the French query is \"basket\". Can we change that to Spanish? We start by taking one relatively low-level French feature, that feeds into all of the others."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "french_supernode_features = [Feature(layer=20,pos=-1,feature_idx=1454)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "But what should we set the values of the French supernode features to be? Ideally, we set them to some in-distribution values. To do this, we can get the activations of these nodes on the French input sentence. We'll get these as a sparse tensor, to save on memory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "s_spanish = \"Hecho: Michael Jordan juega al\"\n",
    "_, activations = model.get_activations(s_spanish, sparse=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we construct and perform the intervention! Each supernode_feature contains precisely the information needed to index into `activations`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "spanish_supernode_features = supernode_features  # from before\n",
    "fr_es_intervention_tuples = [(*supernode_feature, 0.0) for supernode_feature in french_supernode_features] \n",
    "fr_es_intervention_tuples+= [(*supernode_feature, 10*activations[supernode_feature]) for (supernode_feature) in spanish_supernode_features]"
   ]
  },
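  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The construction above follows a simple pattern: ablate the features of the source language and set the features of the target language to a scaled copy of their observed activations. The following is a minimal sketch of that pattern; `swap_tuples` is a hypothetical helper, and a plain dict with a made-up activation value stands in for the sparse `activations` tensor.\n",
    "\n",
    "```python\n",
    "from collections import namedtuple\n",
    "\n",
    "Feature = namedtuple('Feature', ['layer', 'pos', 'feature_idx'])\n",
    "\n",
    "\n",
    "def swap_tuples(off_features, on_features, activations, scale=10.0):\n",
    "    # hypothetical helper: zero the 'off' features...\n",
    "    tuples = [(*feature, 0.0) for feature in off_features]\n",
    "    # ...and set each 'on' feature to a multiple of its observed activation\n",
    "    tuples += [(*feature, scale * activations[feature]) for feature in on_features]\n",
    "    return tuples\n",
    "\n",
    "\n",
    "french = [Feature(layer=20, pos=-1, feature_idx=1454)]\n",
    "spanish = [Feature(layer=20, pos=-1, feature_idx=341)]\n",
    "toy_activations = {spanish[0]: 1.5}  # made-up value, for illustration only\n",
    "print(swap_tuples(french, spanish, toy_activations))\n",
    "```\n",
    "\n",
    "The `scale` argument plays the role of the factor of 10 used in this demo; smaller values keep the intervention closer to in-distribution activations."
   ]
  },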
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <style>\n",
       "    .token-viz {\n",
       "        font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;\n",
       "        margin-bottom: 10px;\n",
       "        max-width: 700px;\n",
       "    }\n",
       "    .token-viz .header {\n",
       "        font-weight: bold;\n",
       "        font-size: 14px;\n",
       "        margin-bottom: 3px;\n",
       "        padding: 4px 6px;\n",
       "        border-radius: 3px;\n",
       "        color: white;\n",
       "        display: inline-block;\n",
       "    }\n",
       "    .token-viz .sentence {\n",
       "        background-color: rgba(200, 200, 200, 0.2);\n",
       "        padding: 4px 6px;\n",
       "        border-radius: 3px;\n",
       "        border: 1px solid rgba(100, 100, 100, 0.5);\n",
       "        font-family: monospace;\n",
       "        margin-bottom: 8px;\n",
       "        font-weight: 500;\n",
       "        font-size: 14px;\n",
       "    }\n",
       "    .token-viz table {\n",
       "        width: 100%;\n",
       "        border-collapse: collapse;\n",
       "        margin-bottom: 8px;\n",
       "        font-size: 13px;\n",
       "        table-layout: fixed;\n",
       "    }\n",
       "    .token-viz th {\n",
       "        text-align: left;\n",
       "        padding: 4px 6px;\n",
       "        font-weight: bold;\n",
       "        border: 1px solid rgba(150, 150, 150, 0.5);\n",
       "        background-color: rgba(200, 200, 200, 0.3);\n",
       "    }\n",
       "    .token-viz td {\n",
       "        padding: 3px 6px;\n",
       "        border: 1px solid rgba(150, 150, 150, 0.5);\n",
       "        font-weight: 500;\n",
       "        overflow: hidden;\n",
       "        text-overflow: ellipsis;\n",
       "        white-space: nowrap;\n",
       "    }\n",
       "    .token-viz .token-col {\n",
       "        width: 20%;\n",
       "    }\n",
       "    .token-viz .prob-col {\n",
       "        width: 15%;\n",
       "    }\n",
       "    .token-viz .dist-col {\n",
       "        width: 65%;\n",
       "    }\n",
       "    .token-viz .monospace {\n",
       "        font-family: monospace;\n",
       "    }\n",
       "    .token-viz .bar-container {\n",
       "        display: flex;\n",
       "        align-items: center;\n",
       "    }\n",
       "    .token-viz .bar {\n",
       "        height: 12px;\n",
       "        min-width: 2px;\n",
       "    }\n",
       "    .token-viz .bar-text {\n",
       "        margin-left: 6px;\n",
       "        font-weight: 500;\n",
       "        font-size: 12px;\n",
       "    }\n",
       "    .token-viz .even-row {\n",
       "        background-color: rgba(240, 240, 240, 0.1);\n",
       "    }\n",
       "    .token-viz .odd-row {\n",
       "        background-color: rgba(255, 255, 255, 0.1);\n",
       "    }\n",
       "    </style>\n",
       "    \n",
       "    <div class=\"token-viz\">\n",
       "        <div class=\"header\" style=\"background-color: #555555;\">Input Sentence:</div>\n",
       "        <div class=\"sentence\">Fait: Michael Jordan joue au</div>\n",
       "        \n",
       "        <div>\n",
       "            <div class=\"header\" style=\"background-color: #2471A3;\">Original Top 5 Tokens</div>\n",
       "            <table>\n",
       "                <thead>\n",
       "                    <tr>\n",
       "                        <th class=\"token-col\">Token</th>\n",
       "                        <th class=\"prob-col\" style=\"text-align: right;\">Probability</th>\n",
       "                        <th class=\"dist-col\">Distribution</th>\n",
       "                    </tr>\n",
       "                </thead>\n",
       "                <tbody>\n",
       "    \n",
       "                    <tr class=\"even-row\">\n",
       "                        <td class=\"monospace token-col\" title=\" basket\"> basket</td>\n",
       "                        <td class=\"prob-col\" style=\"text-align: right;\">0.566</td>\n",
       "                        <td class=\"dist-col\">\n",
       "                            <div class=\"bar-container\">\n",
       "                                <div class=\"bar\" style=\"background-color: #2471A3; width: 100%;\"></div>\n",
       "                                <span class=\"bar-text\">56.6%</span>\n",
       "                            </div>\n",
       "                        </td>\n",
       "                    </tr>\n",
       "        \n",
       "                    <tr class=\"odd-row\">\n",
       "                        <td class=\"monospace token-col\" title=\" basketball\"> basketball</td>\n",
       "                        <td class=\"prob-col\" style=\"text-align: right;\">0.111</td>\n",
       "                        <td class=\"dist-col\">\n",
       "                            <div class=\"bar-container\">\n",
       "                                <div class=\"bar\" style=\"background-color: #2471A3; width: 19%;\"></div>\n",
       "                                <span class=\"bar-text\">11.1%</span>\n",
       "                            </div>\n",
       "                        </td>\n",
       "                    </tr>\n",
       "        \n",
       "                    <tr class=\"even-row\">\n",
       "                        <td class=\"monospace token-col\" title=\" golf\"> golf</td>\n",
       "                        <td class=\"prob-col\" style=\"text-align: right;\">0.067</td>\n",
       "                        <td class=\"dist-col\">\n",
       "                            <div class=\"bar-container\">\n",
       "                                <div class=\"bar\" style=\"background-color: #2471A3; width: 11%;\"></div>\n",
       "                                <span class=\"bar-text\">6.7%</span>\n",
       "                            </div>\n",
       "                        </td>\n",
       "                    </tr>\n",
       "        \n",
       "                    <tr class=\"odd-row\">\n",
       "                        <td class=\"monospace token-col\" title=\" baseball\"> baseball</td>\n",
       "                        <td class=\"prob-col\" style=\"text-align: right;\">0.041</td>\n",
       "                        <td class=\"dist-col\">\n",
       "                            <div class=\"bar-container\">\n",
       "                                <div class=\"bar\" style=\"background-color: #2471A3; width: 7%;\"></div>\n",
       "                                <span class=\"bar-text\">4.1%</span>\n",
       "                            </div>\n",
       "                        </td>\n",
       "                    </tr>\n",
       "        \n",
       "                    <tr class=\"even-row\">\n",
       "                        <td class=\"monospace token-col\" title=\" football\"> football</td>\n",
       "                        <td class=\"prob-col\" style=\"text-align: right;\">0.028</td>\n",
       "                        <td class=\"dist-col\">\n",
       "                            <div class=\"bar-container\">\n",
       "                                <div class=\"bar\" style=\"background-color: #2471A3; width: 4%;\"></div>\n",
       "                                <span class=\"bar-text\">2.8%</span>\n",
       "                            </div>\n",
       "                        </td>\n",
       "                    </tr>\n",
       "        \n",
       "                </tbody>\n",
       "            </table>\n",
       "            \n",
       "            <div class=\"header\" style=\"background-color: #27AE60;\">New Top 5 Tokens</div>\n",
       "            <table>\n",
       "                <thead>\n",
       "                    <tr>\n",
       "                        <th class=\"token-col\">Token</th>\n",
       "                        <th class=\"prob-col\" style=\"text-align: right;\">Probability</th>\n",
       "                        <th class=\"dist-col\">Distribution</th>\n",
       "                    </tr>\n",
       "                </thead>\n",
       "                <tbody>\n",
       "    \n",
       "                    <tr class=\"even-row\">\n",
       "                        <td class=\"monospace token-col\" title=\" baloncesto\"> baloncesto</td>\n",
       "                        <td class=\"prob-col\" style=\"text-align: right;\">0.289</td>\n",
       "                        <td class=\"dist-col\">\n",
       "                            <div class=\"bar-container\">\n",
       "                                <div class=\"bar\" style=\"background-color: #27AE60; width: 51%;\"></div>\n",
       "                                <span class=\"bar-text\">28.9%</span>\n",
       "                            </div>\n",
       "                        </td>\n",
       "                    </tr>\n",
       "        \n",
       "                    <tr class=\"odd-row\">\n",
       "                        <td class=\"monospace token-col\" title=\" golf\"> golf</td>\n",
       "                        <td class=\"prob-col\" style=\"text-align: right;\">0.198</td>\n",
       "                        <td class=\"dist-col\">\n",
       "                            <div class=\"bar-container\">\n",
       "                                <div class=\"bar\" style=\"background-color: #27AE60; width: 35%;\"></div>\n",
       "                                <span class=\"bar-text\">19.8%</span>\n",
       "                            </div>\n",
       "                        </td>\n",
       "                    </tr>\n",
       "        \n",
       "                    <tr class=\"even-row\">\n",
       "                        <td class=\"monospace token-col\" title=\" fútbol\"> fútbol</td>\n",
       "                        <td class=\"prob-col\" style=\"text-align: right;\">0.057</td>\n",
       "                        <td class=\"dist-col\">\n",
       "                            <div class=\"bar-container\">\n",
       "                                <div class=\"bar\" style=\"background-color: #27AE60; width: 10%;\"></div>\n",
       "                                <span class=\"bar-text\">5.7%</span>\n",
       "                            </div>\n",
       "                        </td>\n",
       "                    </tr>\n",
       "        \n",
       "                    <tr class=\"odd-row\">\n",
       "                        <td class=\"monospace token-col\" title=\" béisbol\"> béisbol</td>\n",
       "                        <td class=\"prob-col\" style=\"text-align: right;\">0.044</td>\n",
       "                        <td class=\"dist-col\">\n",
       "                            <div class=\"bar-container\">\n",
       "                                <div class=\"bar\" style=\"background-color: #27AE60; width: 7%;\"></div>\n",
       "                                <span class=\"bar-text\">4.4%</span>\n",
       "                            </div>\n",
       "                        </td>\n",
       "                    </tr>\n",
       "        \n",
       "                    <tr class=\"even-row\">\n",
       "                        <td class=\"monospace token-col\" title=\" basketball\"> basketball</td>\n",
       "                        <td class=\"prob-col\" style=\"text-align: right;\">0.044</td>\n",
       "                        <td class=\"dist-col\">\n",
       "                            <div class=\"bar-container\">\n",
       "                                <div class=\"bar\" style=\"background-color: #27AE60; width: 7%;\"></div>\n",
       "                                <span class=\"bar-text\">4.4%</span>\n",
       "                            </div>\n",
       "                        </td>\n",
       "                    </tr>\n",
       "        \n",
       "                </tbody>\n",
       "            </table>\n",
       "        </div>\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "s_french = \"Fait: Michael Jordan joue au\"\n",
    "\n",
    "with torch.inference_mode():\n",
    "    original_logits = model(s_french)\n",
    "    new_logits, _ = model.feature_intervention(s_french, fr_es_intervention_tuples)\n",
    "\n",
    "display_topk_token_predictions(s_french, original_logits, new_logits)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example: Interventions + Sampling\n",
    "Both interventions so far targeted the last token of the sentence; interventions at other positions work analogously. But what if we want the intervention to stay active while the model generates new tokens? We can do this by setting the position of the intervention to an open-ended slice, `slice(pos, None, None)`, which covers `pos` and every position after it. Here we set `pos` to the last token of the original input, but you can also set it to an earlier position."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Position of the last input token; the open-ended slice below also covers every token generated after it.\n",
    "sequence_length = len(model.tokenizer(s_spanish).input_ids)\n",
    "original_feature_pos = sequence_length - 1\n",
    "open_ended_slice = slice(original_feature_pos, None, None)\n",
    "# French supernode features get the value 0.0; Spanish supernode features get 10x the activation recorded at their original position.\n",
    "open_ended_es_fr_intervention_tuples = [(layer, open_ended_slice, feature_idx, 0.0) for (layer, _, feature_idx) in french_supernode_features]\n",
    "open_ended_es_fr_intervention_tuples += [(layer, open_ended_slice, feature_idx, 10 * activations[layer, orig_pos, feature_idx]) for (layer, orig_pos, feature_idx) in spanish_supernode_features]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we generate by calling `feature_intervention_generate`. Make sure `use_past_kv_cache` is set to `False`; otherwise the model will generate using the KV cache and length-1 inputs, which is more efficient but makes interventions hard. `do_sample` is off here so that the outputs are consistent across runs, but you can turn it on as well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <style>\n",
       "    .generations-viz {\n",
       "        font-family: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;\n",
       "        margin-bottom: 12px;\n",
       "        font-size: 13px;\n",
       "        max-width: 700px;\n",
       "    }\n",
       "    .generations-viz .section-header {\n",
       "        font-weight: bold;\n",
       "        font-size: 14px;\n",
       "        margin: 10px 0 5px 0;\n",
       "        padding: 4px 6px;\n",
       "        border-radius: 3px;\n",
       "        color: white;\n",
       "        display: block;\n",
       "    }\n",
       "    .generations-viz .pre-intervention-header {\n",
       "        background-color: #2471A3;\n",
       "    }\n",
       "    .generations-viz .post-intervention-header {\n",
       "        background-color: #27AE60;\n",
       "    }\n",
       "    .generations-viz .generation-container {\n",
       "        margin-bottom: 8px;\n",
       "        padding: 3px;\n",
       "        border-left: 3px solid rgba(100, 100, 100, 0.5);\n",
       "    }\n",
       "    .generations-viz .generation-text {\n",
       "        background-color: rgba(200, 200, 200, 0.2);\n",
       "        padding: 6px 8px;\n",
       "        border-radius: 3px;\n",
       "        border: 1px solid rgba(100, 100, 100, 0.5);\n",
       "        font-family: monospace;\n",
       "        font-weight: 500;\n",
       "        white-space: pre-wrap;\n",
       "        line-height: 1.2;\n",
       "        font-size: 13px;\n",
       "        overflow-x: auto;\n",
       "    }\n",
       "    .generations-viz .base-text {\n",
       "        color: rgba(100, 100, 100, 0.9);\n",
       "    }\n",
       "    .generations-viz .new-text {\n",
       "        background-color: rgba(255, 255, 0, 0.25);\n",
       "        font-weight: bold;\n",
       "        padding: 1px 0;\n",
       "        border-radius: 2px;\n",
       "    }\n",
       "    .generations-viz .pre-intervention-item {\n",
       "        border-left-color: #2471A3;\n",
       "    }\n",
       "    .generations-viz .post-intervention-item {\n",
       "        border-left-color: #27AE60;\n",
       "    }\n",
       "    .generations-viz .generation-number {\n",
       "        font-weight: bold;\n",
       "        margin-bottom: 3px;\n",
       "        color: rgba(70, 70, 70, 0.9);\n",
       "        font-size: 12px;\n",
       "    }\n",
       "    </style>\n",
       "    \n",
       "    <div class=\"generations-viz\">\n",
       "    \n",
       "    <div class=\"section-header pre-intervention-header\">Pre-intervention generations:</div>\n",
       "    \n",
       "        <div class=\"generation-container pre-intervention-item\">\n",
       "            <div class=\"generation-number\">Generation 1</div>\n",
       "            <div class=\"generation-text\"><span class=\"base-text\">Fait: Michael Jordan joue au</span><span class=\"new-text\"> basket avec son fils, Jeffrey Jordan, à la</span></div>\n",
       "        </div>\n",
       "        \n",
       "    <div class=\"section-header post-intervention-header\">Post-intervention generations:</div>\n",
       "    \n",
       "        <div class=\"generation-container post-intervention-item\">\n",
       "            <div class=\"generation-number\">Generation 1</div>\n",
       "            <div class=\"generation-text\"><span class=\"base-text\">Fait: Michael Jordan joue au</span><span class=\"new-text\"> baloncesto.\n",
       "\n",
       "Fato: Michael Jordan es un</span></div>\n",
       "        </div>\n",
       "        \n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "pre_intervention_generation = [model.generate(s_french, do_sample=False, use_past_kv_cache=False, verbose=False)]\n",
    "post_intervention_generation = [model.feature_intervention_generate(s_french, open_ended_es_fr_intervention_tuples, do_sample=False, verbose=False)[0]]\n",
    "\n",
    "display_generations_comparison(s_french, pre_intervention_generation, post_intervention_generation)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
